{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Fusing Convolution and Batch Norm using Custom Function\n\nFusing adjacent convolution and batch norm layers together is typically an\ninference-time optimization to improve run-time. It is usually achieved\nby eliminating the batch norm layer entirely and updating the weight\nand bias of the preceding convolution [0]. However, this technique is not\napplicable for training models.\n\nIn this tutorial, we will show a different technique to fuse the two layers\nthat can be applied during training. Rather than improved runtime, the\nobjective of this optimization is to reduce memory usage.\n\nThe key observation behind this optimization is that both convolution and\nbatch norm (as well as many other ops) need to save a copy of their input\nduring forward for the backward pass. For large\nbatch sizes, these saved inputs are responsible for most of your memory usage,\nso being able to avoid allocating another input tensor for every\nconvolution-batch norm pair can be a significant reduction.\n\nIn this tutorial, we avoid this extra allocation by combining convolution\nand batch norm into a single layer (as a custom function). In the forward\nof this combined layer, we perform normal convolution and batch norm as-is,\nwith the only difference being that we will only save the inputs to the convolution.\nTo obtain the input of batch norm, which is necessary to backward through\nit, we recompute convolution forward again during the backward pass.\n\nIt is important to note that the usage of this optimization is situational.\nThough (by avoiding one buffer saved) we always reduce the memory allocated at\nthe end of the forward pass, there are cases when the *peak* memory allocated\nmay not actually be reduced. See the final section for more details.\n\nFor simplicity, in this tutorial we hardcode `bias=False`, `stride=1`, `padding=0`, `dilation=1`,\nand `groups=1` for Conv2D. For BatchNorm2D, we hardcode `eps=1e-3`, `momentum=0.1`,\n`affine=False`, and `track_running_stats=False`. Another small difference\nis that we add epsilon in the denominator outside of the square root in the computation\nof batch norm.\n\n[0] https://nenadmarkus.com/p/fusing-batchnorm-and-conv/\n"
      ]
    },
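    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a brief aside, the inference-time folding mentioned above can be sketched as follows.\n",
        "This is only an illustration: the symbols $\\gamma$, $\\beta$ (batch norm affine parameters),\n",
        "$\\mu$, $\\sigma^2$ (running statistics), and $b$ (convolution bias) are assumptions introduced here\n",
        "and are not used in the rest of this tutorial, and $\\epsilon$ is placed inside the square root\n",
        "as in standard batch norm. For each output channel $o$:\n",
        "\n",
        "$$\n",
        "W'_o = \\frac{\\gamma_o}{\\sqrt{\\sigma_o^2 + \\epsilon}}\\, W_o, \\qquad\n",
        "b'_o = \\frac{\\gamma_o\\,(b_o - \\mu_o)}{\\sqrt{\\sigma_o^2 + \\epsilon}} + \\beta_o\n",
        "$$\n",
        "\n",
        "The folding treats $\\mu$ and $\\sigma^2$ as constants. In training mode they are computed from the\n",
        "current batch, which is why this approach does not carry over to training.\n"
      ]
    },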
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Backward Formula Implementation for Convolution\nImplementing a custom function requires us to implement the backward\nourselves. In this case, we need both the backward formulas for Conv2D\nand BatchNorm2D. Eventually we will chain them together in our unified\nbackward function, but below we first implement them as their own\ncustom functions so we can validate their correctness individually.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch\nfrom torch.autograd.function import once_differentiable\nimport torch.nn.functional as F\n\ndef convolution_backward(grad_out, X, weight):\n    # Gradient w.r.t. the weight: correlate the input with grad_out,\n    # swapping the batch and channel dimensions of both operands\n    grad_weight = F.conv2d(X.transpose(0, 1), grad_out.transpose(0, 1)).transpose(0, 1)\n    # Gradient w.r.t. the input: transposed convolution of grad_out with the weight\n    grad_X = F.conv_transpose2d(grad_out, weight)\n    return grad_X, grad_weight\n\nclass Conv2D(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, X, weight):\n        ctx.save_for_backward(X, weight)\n        return F.conv2d(X, weight)\n\n    # Use @once_differentiable by default unless we intend to double backward\n    @staticmethod\n    @once_differentiable\n    def backward(ctx, grad_out):\n        X, weight = ctx.saved_tensors\n        return convolution_backward(grad_out, X, weight)"
      ]
    },
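    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The two calls in `convolution_backward` correspond to the index formulas below, written as a sketch\n",
        "for the `stride=1`, `padding=0` case hardcoded in this tutorial. The symbols $Y$ (convolution output),\n",
        "$G = \\partial L / \\partial Y$ (`grad_out`), and $W$ (`weight`) are notation introduced for this note only.\n",
        "\n",
        "$$\n",
        "Y_{n,o,i,j} = \\sum_{c,p,q} X_{n,c,i+p,j+q}\\, W_{o,c,p,q}\n",
        "$$\n",
        "\n",
        "$$\n",
        "\\frac{\\partial L}{\\partial W_{o,c,p,q}} = \\sum_{n,i,j} X_{n,c,i+p,j+q}\\, G_{n,o,i,j}, \\qquad\n",
        "\\frac{\\partial L}{\\partial X_{n,c,a,b}} = \\sum_{o,p,q} G_{n,o,a-p,b-q}\\, W_{o,c,p,q}\n",
        "$$\n",
        "\n",
        "Terms whose indices fall outside the valid range are dropped. The first gradient sums over the batch\n",
        "index $n$, which is why the `F.conv2d` call swaps the batch and channel dimensions of both operands.\n",
        "The second gradient is what the `F.conv_transpose2d` call computes.\n"
      ]
    },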
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "When testing with gradcheck, it is important to use double precision.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "weight = torch.rand(5, 3, 3, 3, requires_grad=True, dtype=torch.double)\nX = torch.rand(10, 3, 7, 7, requires_grad=True, dtype=torch.double)\ntorch.autograd.gradcheck(Conv2D.apply, (X, weight))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Backward Formula Implementation for Batch Norm\nBatch Norm has two modes: training and eval mode. In training mode\nthe sample statistics are a function of the inputs. In eval mode,\nwe use the saved running statistics, which are not a function of the inputs.\nThis makes non-training mode's backward significantly simpler. Below\nwe implement and test only the training mode case.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def unsqueeze_all(t):\n    # Helper function to unsqueeze all the dimensions that we reduce over\n    return t[None, :, None, None]\n\ndef batch_norm_backward(grad_out, X, sum, sqrt_var, N, eps):\n    # We use the formula: out = (X - mean(X)) / (sqrt(var(X)) + eps)\n    # in batch norm 2d's forward. To simplify our derivation, we follow the\n    # chain rule and compute the gradients as follows before accumulating\n    # them all into a final grad_input.\n    #  1) 'grad of out wrt var(X)' * 'grad of var(X) wrt X'\n    #  2) 'grad of out wrt mean(X)' * 'grad of mean(X) wrt X'\n    #  3) 'grad of out wrt X in the numerator' * 'grad of X wrt X'\n    # We then rewrite the formulas to use as few extra buffers as possible\n    tmp = ((X - unsqueeze_all(sum) / N) * grad_out).sum(dim=(0, 2, 3))\n    tmp *= -1\n    d_denom = tmp / (sqrt_var + eps)**2  # d_denom = -num / denom**2\n    # It is useful to delete tensors when you no longer need them with `del`\n    # For example, we could've done `del tmp` here because we won't use it later\n    # In this case, it's not a big difference because tmp only has size of (C,)\n    # The important thing is avoid allocating NCHW-sized tensors unnecessarily\n    d_var = d_denom / (2 * sqrt_var)  # denom = torch.sqrt(var) + eps\n    # Compute d_mean_dx before allocating the final NCHW-sized grad_input buffer\n    d_mean_dx = grad_out / unsqueeze_all(sqrt_var + eps)\n    d_mean_dx = unsqueeze_all(-d_mean_dx.sum(dim=(0, 2, 3)) / N)\n    # d_mean_dx has already been reassigned to a C-sized buffer so no need to worry\n\n    # (1) unbiased_var(x) = ((X - unsqueeze_all(mean))**2).sum(dim=(0, 2, 3)) / (N - 1)\n    grad_input = X * unsqueeze_all(d_var * N)\n    grad_input += unsqueeze_all(-d_var * sum)\n    grad_input *= 2 / ((N - 1) * N)\n    # (2) mean (see above)\n    grad_input += d_mean_dx\n    # (3) Add 'grad_out / <factor>' without allocating an extra buffer\n    grad_input *= unsqueeze_all(sqrt_var + eps)\n    grad_input += 
grad_out\n    grad_input /= unsqueeze_all(sqrt_var + eps)  # sqrt_var + eps > 0!\n    return grad_input\n\nclass BatchNorm(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, X, eps=1e-3):\n        # Don't save keepdim'd values for backward\n        sum = X.sum(dim=(0, 2, 3))\n        var = X.var(unbiased=True, dim=(0, 2, 3))\n        N = X.numel() / X.size(1)\n        sqrt_var = torch.sqrt(var)\n        ctx.save_for_backward(X)\n        ctx.eps = eps\n        ctx.sum = sum\n        ctx.N = N\n        ctx.sqrt_var = sqrt_var\n        mean = sum / N\n        denom = sqrt_var + eps\n        out = X - unsqueeze_all(mean)\n        out /= unsqueeze_all(denom)\n        return out\n\n    @staticmethod\n    @once_differentiable\n    def backward(ctx, grad_out):\n        X, = ctx.saved_tensors\n        return batch_norm_backward(grad_out, X, ctx.sum, ctx.sqrt_var, ctx.N, ctx.eps)"
      ]
    },
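    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The three numbered terms in `batch_norm_backward` can be written out explicitly. As a sketch, for a\n",
        "single channel let $x_k$ range over the $N$ values of that channel, $g_k$ be the matching entries of\n",
        "`grad_out`, $\\mu$ the mean, $v$ the unbiased variance, and $\\sigma = \\sqrt{v}$. These symbols are\n",
        "notation for this note only.\n",
        "\n",
        "$$\n",
        "y_k = \\frac{x_k - \\mu}{\\sigma + \\epsilon}, \\qquad\n",
        "\\mu = \\frac{1}{N} \\sum_i x_i, \\qquad\n",
        "v = \\frac{1}{N - 1} \\sum_i (x_i - \\mu)^2\n",
        "$$\n",
        "\n",
        "$$\n",
        "\\frac{\\partial L}{\\partial v} = -\\frac{\\sum_i (x_i - \\mu)\\, g_i}{2 \\sigma (\\sigma + \\epsilon)^2}\n",
        "$$\n",
        "\n",
        "$$\n",
        "\\frac{\\partial L}{\\partial x_k} =\n",
        "\\underbrace{\\frac{\\partial L}{\\partial v} \\cdot \\frac{2 (x_k - \\mu)}{N - 1}}_{(1)}\n",
        "\\; \\underbrace{- \\frac{1}{N (\\sigma + \\epsilon)} \\sum_i g_i}_{(2)}\n",
        "\\; + \\underbrace{\\frac{g_k}{\\sigma + \\epsilon}}_{(3)}\n",
        "$$\n",
        "\n",
        "Here $\\partial L / \\partial v$ corresponds to `d_var` and term (2) to `d_mean_dx`. The dependence of\n",
        "$v$ on $\\mu$ contributes nothing extra because $\\sum_i (x_i - \\mu) = 0$.\n"
      ]
    },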
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Test the batch norm backward with gradcheck, again in double precision.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "a = torch.rand(1, 2, 3, 4, requires_grad=True, dtype=torch.double)\ntorch.autograd.gradcheck(BatchNorm.apply, (a,), fast_mode=False)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Fusing Convolution and BatchNorm\nNow that the bulk of the work has been done, we can combine\nthem together. Note that in (1) we only save a single buffer\nfor backward, but this also means we recompute convolution forward\nin (5). Also see that in (2), (3), (4), and (6), it's the same\nexact code as the examples above.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class FusedConvBN2DFunction(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, X, conv_weight, eps=1e-3):\n        assert X.ndim == 4  # N, C, H, W\n        # (1) Only need to save this single activation buffer (plus the weight) for backward!\n        ctx.save_for_backward(X, conv_weight)\n\n        # (2) Exact same Conv2D forward from example above\n        X = F.conv2d(X, conv_weight)\n        # (3) Exact same BatchNorm2D forward from example above\n        sum = X.sum(dim=(0, 2, 3))\n        var = X.var(unbiased=True, dim=(0, 2, 3))\n        N = X.numel() / X.size(1)\n        sqrt_var = torch.sqrt(var)\n        ctx.eps = eps\n        ctx.sum = sum\n        ctx.N = N\n        ctx.sqrt_var = sqrt_var\n        mean = sum / N\n        denom = sqrt_var + eps\n        # Try to do as many things in-place as possible\n        # Instead of `out = (X - a) / b`, doing `out = X - a; out /= b`\n        # avoids allocating one extra NCHW-sized buffer here\n        out = X - unsqueeze_all(mean)\n        out /= unsqueeze_all(denom)\n        return out\n\n    @staticmethod\n    @once_differentiable\n    def backward(ctx, grad_out):\n        X, conv_weight, = ctx.saved_tensors\n        # (4) Batch norm backward\n        # (5) We need to recompute conv\n        X_conv_out = F.conv2d(X, conv_weight)\n        grad_out = batch_norm_backward(grad_out, X_conv_out, ctx.sum, ctx.sqrt_var,\n                                       ctx.N, ctx.eps)\n        # (6) Conv2d backward\n        grad_X, grad_weight = convolution_backward(grad_out, X, conv_weight)\n        # One gradient per forward input: X, conv_weight, eps\n        return grad_X, grad_weight, None"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The next step is to wrap our functional variant in a stateful\n`nn.Module`.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch.nn as nn\nimport math\n\nclass FusedConvBN(nn.Module):\n    def __init__(self, in_channels, out_channels, kernel_size, exp_avg_factor=0.1,\n                 eps=1e-3, device=None, dtype=None):\n        super(FusedConvBN, self).__init__()\n        factory_kwargs = {'device': device, 'dtype': dtype}\n        # Conv parameters\n        weight_shape = (out_channels, in_channels, kernel_size, kernel_size)\n        self.conv_weight = nn.Parameter(torch.empty(*weight_shape, **factory_kwargs))\n        # Batch norm parameters\n        num_features = out_channels\n        self.num_features = num_features\n        self.eps = eps\n        # Initialize\n        self.reset_parameters()\n\n    def forward(self, X):\n        return FusedConvBN2DFunction.apply(X, self.conv_weight, self.eps)\n\n    def reset_parameters(self) -> None:\n        nn.init.kaiming_uniform_(self.conv_weight, a=math.sqrt(5))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Use gradcheck to validate the correctness of our backward formula.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "weight = torch.rand(5, 3, 3, 3, requires_grad=True, dtype=torch.double)\nX = torch.rand(2, 3, 4, 4, requires_grad=True, dtype=torch.double)\ntorch.autograd.gradcheck(FusedConvBN2DFunction.apply, (X, weight))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Testing out our new Layer\nUse FusedConvBN to train a basic network.\nThe code below is lightly modified from the example here:\nhttps://github.com/pytorch/examples/tree/master/mnist\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch.optim as optim\nfrom torchvision import datasets, transforms\nfrom torch.optim.lr_scheduler import StepLR\n\n# Record memory allocated at the end of the forward pass\n# Index 0 holds the fused runs, index 1 the unfused runs\nmemory_allocated = [[], []]\n\nclass Net(nn.Module):\n    def __init__(self, fused=True):\n        super(Net, self).__init__()\n        self.fused = fused\n        if fused:\n            self.convbn1 = FusedConvBN(1, 32, 3)\n            self.convbn2 = FusedConvBN(32, 64, 3)\n        else:\n            self.conv1 = nn.Conv2d(1, 32, 3, 1, bias=False)\n            self.bn1 = nn.BatchNorm2d(32, affine=False, track_running_stats=False)\n            self.conv2 = nn.Conv2d(32, 64, 3, 1, bias=False)\n            self.bn2 = nn.BatchNorm2d(64, affine=False, track_running_stats=False)\n        self.fc1 = nn.Linear(9216, 128)\n        self.dropout = nn.Dropout(0.5)\n        self.fc2 = nn.Linear(128, 10)\n\n    def forward(self, x):\n        if self.fused:\n            x = self.convbn1(x)\n        else:\n            x = self.conv1(x)\n            x = self.bn1(x)\n        F.relu_(x)\n        if self.fused:\n            x = self.convbn2(x)\n        else:\n            x = self.conv2(x)\n            x = self.bn2(x)\n        F.relu_(x)\n        x = F.max_pool2d(x, 2)\n        F.relu_(x)\n        x = x.flatten(1)\n        x = self.fc1(x)\n        x = self.dropout(x)\n        F.relu_(x)\n        x = self.fc2(x)\n        output = F.log_softmax(x, dim=1)\n        # Use self.fused rather than a global variable named `fused`\n        if self.fused:\n            memory_allocated[0].append(torch.cuda.memory_allocated())\n        else:\n            memory_allocated[1].append(torch.cuda.memory_allocated())\n        return output\n\ndef train(model, device, train_loader, optimizer, epoch):\n    model.train()\n    for batch_idx, (data, target) in enumerate(train_loader):\n        data, target = data.to(device), target.to(device)\n        optimizer.zero_grad()\n        output = model(data)\n        loss = F.nll_loss(output, target)\n        loss.backward()\n        optimizer.step()\n        if batch_idx % 2 == 0:\n            print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n                epoch, batch_idx * len(data), len(train_loader.dataset),\n                100. * batch_idx / len(train_loader), loss.item()))\n\ndef test(model, device, test_loader):\n    model.eval()\n    test_loss = 0\n    correct = 0\n    # Use inference mode instead of no_grad, for free improved test-time performance\n    with torch.inference_mode():\n        for data, target in test_loader:\n            data, target = data.to(device), target.to(device)\n            output = model(data)\n            # sum up batch loss\n            test_loss += F.nll_loss(output, target, reduction='sum').item()\n            # get the index of the max log-probability\n            pred = output.argmax(dim=1, keepdim=True)\n            correct += pred.eq(target.view_as(pred)).sum().item()\n\n    test_loss /= len(test_loader.dataset)\n\n    print('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n        test_loss, correct, len(test_loader.dataset),\n        100. * correct / len(test_loader.dataset)))\n\nuse_cuda = torch.cuda.is_available()\ndevice = torch.device(\"cuda\" if use_cuda else \"cpu\")\ntrain_kwargs = {'batch_size': 2048}\ntest_kwargs = {'batch_size': 2048}\n\nif use_cuda:\n    cuda_kwargs = {'num_workers': 1,\n                   'pin_memory': True,\n                   'shuffle': True}\n    train_kwargs.update(cuda_kwargs)\n    test_kwargs.update(cuda_kwargs)\n\ntransform = transforms.Compose([\n    transforms.ToTensor(),\n    transforms.Normalize((0.1307,), (0.3081,))\n])\ndataset1 = datasets.MNIST('../data', train=True, download=True,\n                          transform=transform)\ndataset2 = datasets.MNIST('../data', train=False,\n                          transform=transform)\ntrain_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)\ntest_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## A Comparison of Memory Usage\nIf CUDA is enabled, print out memory usage for both `fused=True` and `fused=False`.\nFor an example run on RTX 3070, CuDNN 8.0.5: fused peak memory: 1.56GB,\nunfused peak memory: 2.68GB.\n\nIt is important to note that the *peak* memory usage for this model may vary depending on\nthe specific CuDNN convolution algorithm used. For shallower models, it\nmay be possible for the peak memory allocated of the fused model to exceed\nthat of the unfused model! This is because the memory allocated to compute\ncertain CuDNN convolution algorithms can be high enough to \"hide\" the typical peak\nyou would expect to be near the start of the backward pass.\n\nFor this reason, we also record and display the memory allocated at the end\nof the forward pass as an approximation, and to demonstrate that we indeed\nallocate one fewer buffer per fused conv-bn pair.\n\n"
      ]
    },
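    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough back-of-the-envelope estimate of the buffer saved per fused pair, assume `float32`\n",
        "activations (4 bytes each), a full batch of 2048 MNIST images of size $28 \\times 28$, and that the\n",
        "buffer avoided is the convolution output. These are assumptions made for illustration, and measured\n",
        "numbers may differ because of allocator behavior and the smaller final batch.\n",
        "\n",
        "$$\n",
        "\\text{pair 1: } 2048 \\times 32 \\times 26 \\times 26 \\times 4 \\text{ bytes} = 177{,}209{,}344 \\text{ bytes} \\approx 0.17 \\text{ GiB}\n",
        "$$\n",
        "\n",
        "$$\n",
        "\\text{pair 2: } 2048 \\times 64 \\times 24 \\times 24 \\times 4 \\text{ bytes} = 301{,}989{,}888 \\text{ bytes} \\approx 0.28 \\text{ GiB}\n",
        "$$\n",
        "\n",
        "Under these assumptions, the two fused pairs together avoid roughly 0.45 GiB of saved activations at\n",
        "the end of the forward pass.\n"
      ]
    },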
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from statistics import mean\n\ntorch.backends.cudnn.enabled = True\n\nif use_cuda:\n    peak_memory_allocated = []\n\n    for fused in (True, False):\n        torch.manual_seed(123456)\n\n        model = Net(fused=fused).to(device)\n        optimizer = optim.Adadelta(model.parameters(), lr=1.0)\n        scheduler = StepLR(optimizer, step_size=1, gamma=0.7)\n\n        for epoch in range(1):\n            train(model, device, train_loader, optimizer, epoch)\n            test(model, device, test_loader)\n            scheduler.step()\n        peak_memory_allocated.append(torch.cuda.max_memory_allocated())\n        torch.cuda.reset_peak_memory_stats()\n    print(\"CuDNN version:\", torch.backends.cudnn.version())\n    print()\n    print(\"Peak memory allocated:\")\n    print(f\"fused: {peak_memory_allocated[0]/1024**3:.2f}GB, unfused: {peak_memory_allocated[1]/1024**3:.2f}GB\")\n    print(\"Memory allocated at end of forward pass:\")\n    print(f\"fused: {mean(memory_allocated[0])/1024**3:.2f}GB, unfused: {mean(memory_allocated[1])/1024**3:.2f}GB\")"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.4"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}